{
 "cells": [
  {
   "cell_type": "markdown",
   "id": "a86f218a-f1c3-4283-b432-5206639cd7ad",
   "metadata": {},
   "source": [
    "# **Using LLMs to Generate Synthetic Data for Fine-Tuning GLiNER**\n",
    "\n",
    "In this notebook, we'll explore a simple way to generate synthetic data for fine-tuning GLiNER. I have used a similar approach to generate training data for [**PII extraction**](https://huggingface.co/urchade/gliner_multi_pii-v1). We will be using `Mistral-7B-Instruct-v0.2`, though I think there are better LLMs available online (like LLaMa-3 ... etc).\n",
    "\n",
    "Additionally, the prompt used in this example is far from optimal, so you should adapt it to your specific use case or domain. This notebook serves only as an example for practitioners, as some people have requested one.\n",
    "\n",
    "In this notebook, we generate **fully synthetic data**, including both text and entity annotations, but if you have quality data from your target domain, *you can alternatively have the LLM annotate your existing data*. 📊📝\n",
    "\n",
    "Feel free to experiment and tailor the approach to better suit your needs! *Happy fine-tuning!* 🌟"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "id": "acbcbda3-d359-4be8-9dd9-74b51d223fe9",
   "metadata": {},
   "outputs": [],
   "source": [
    "# install vllm (https://github.com/vllm-project/vllm)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "id": "b48ed992-9c8a-4b80-b98f-ac6b018444ed",
   "metadata": {},
   "outputs": [],
   "source": [
    "from vllm import LLM, SamplingParams"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "36b416de-c46f-49e7-a30a-f115d15be1ca",
   "metadata": {},
   "source": [
    "## Load large language model"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "id": "bad246e7-9c25-4510-ba67-5a399159c959",
   "metadata": {},
   "outputs": [],
   "source": [
    "LLM_MODEL = \"mistralai/Mistral-7B-Instruct-v0.2\" # you can use a better model\n",
    "NUM_GPUs = 4"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "id": "8c56b1e1-ac43-45b2-ad3f-a763b53979d1",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "WARNING 05-10 21:18:17 config.py:549] Casting torch.bfloat16 to torch.float16.\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "2024-05-10 21:18:20,760\tINFO worker.py:1724 -- Started a local Ray instance.\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "INFO 05-10 21:18:25 llm_engine.py:72] Initializing an LLM engine with config: model='/gpfsdswork/dataset/HuggingFace_Models/mistralai/Mistral-7B-Instruct-v0.2', tokenizer='/gpfsdswork/dataset/HuggingFace_Models/mistralai/Mistral-7B-Instruct-v0.2', tokenizer_mode=auto, revision=None, tokenizer_revision=None, trust_remote_code=False, dtype=torch.float16, max_seq_len=32768, download_dir=None, load_format=auto, tensor_parallel_size=4, disable_custom_all_reduce=False, quantization=None, enforce_eager=False, kv_cache_dtype=auto, seed=0)\n",
      "INFO 05-10 21:19:02 llm_engine.py:322] # GPU blocks: 42735, # CPU blocks: 8192\n",
      "INFO 05-10 21:19:05 model_runner.py:632] Capturing the model for CUDA graphs. This may lead to unexpected consequences if the model is not static. To run the model in eager mode, set 'enforce_eager=True' or use '--enforce-eager' in the CLI.\n",
      "INFO 05-10 21:19:05 model_runner.py:636] CUDA graphs can take additional 1~3 GiB memory per GPU. If you are running out of memory, consider decreasing `gpu_memory_utilization` or enforcing eager mode. You can also reduce the `max_num_seqs` as needed to decrease memory usage.\n",
      "\u001b[36m(RayWorkerVllm pid=44345)\u001b[0m INFO 05-10 21:19:05 model_runner.py:632] Capturing the model for CUDA graphs. This may lead to unexpected consequences if the model is not static. To run the model in eager mode, set 'enforce_eager=True' or use '--enforce-eager' in the CLI.\n",
      "\u001b[36m(RayWorkerVllm pid=44345)\u001b[0m INFO 05-10 21:19:05 model_runner.py:636] CUDA graphs can take additional 1~3 GiB memory per GPU. If you are running out of memory, consider decreasing `gpu_memory_utilization` or enforcing eager mode. You can also reduce the `max_num_seqs` as needed to decrease memory usage.\n",
      "INFO 05-10 21:19:10 custom_all_reduce.py:199] Registering 2275 cuda graph addresses\n",
      "INFO 05-10 21:19:10 model_runner.py:698] Graph capturing finished in 6 secs.\n",
      "\u001b[36m(RayWorkerVllm pid=44345)\u001b[0m INFO 05-10 21:19:10 custom_all_reduce.py:199] Registering 2275 cuda graph addresses\n",
      "\u001b[36m(RayWorkerVllm pid=44345)\u001b[0m INFO 05-10 21:19:10 model_runner.py:698] Graph capturing finished in 6 secs.\n",
      "\u001b[36m(RayWorkerVllm pid=44502)\u001b[0m INFO 05-10 21:19:05 model_runner.py:632] Capturing the model for CUDA graphs. This may lead to unexpected consequences if the model is not static. To run the model in eager mode, set 'enforce_eager=True' or use '--enforce-eager' in the CLI.\u001b[32m [repeated 2x across cluster] (Ray deduplicates logs by default. Set RAY_DEDUP_LOGS=0 to disable log deduplication, or see https://docs.ray.io/en/master/ray-observability/ray-logging.html#log-deduplication for more options.)\u001b[0m\n",
      "\u001b[36m(RayWorkerVllm pid=44502)\u001b[0m INFO 05-10 21:19:05 model_runner.py:636] CUDA graphs can take additional 1~3 GiB memory per GPU. If you are running out of memory, consider decreasing `gpu_memory_utilization` or enforcing eager mode. You can also reduce the `max_num_seqs` as needed to decrease memory usage.\u001b[32m [repeated 2x across cluster]\u001b[0m\n",
      "\u001b[36m(RayWorkerVllm pid=44429)\u001b[0m INFO 05-10 21:19:10 custom_all_reduce.py:199] Registering 2275 cuda graph addresses\n",
      "\u001b[36m(RayWorkerVllm pid=44429)\u001b[0m INFO 05-10 21:19:10 model_runner.py:698] Graph capturing finished in 6 secs.\n"
     ]
    }
   ],
   "source": [
    "llm = LLM(model=LLM_MODEL, tensor_parallel_size=NUM_GPUs, dtype=\"half\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "id": "f48628b5-30f2-4ad4-85c2-860ad3d9c715",
   "metadata": {},
   "outputs": [],
   "source": [
    "# sampling parameters\n",
    "sampling_params = SamplingParams(top_k=100, max_tokens=1000, top_p=0.8, stop=\"<end>\")"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "0985ea42-7050-4b0f-8e88-d37f50b88e89",
   "metadata": {},
   "source": [
    "## Prompting function"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "id": "0b18ba43-c5ae-4583-b822-4cf262299260",
   "metadata": {},
   "outputs": [],
   "source": [
    "def create_json_prompt_for_synthetic_data(**kwargs):\n",
    "    \n",
    "    # Use dictionary comprehension to filter out 'n/a' values and to keep the code flexible\n",
    "    attributes = {key: value for key, value in kwargs.items() if value != \"n/a\"}\n",
    "    \n",
    "    # Building the initial part of the prompt\n",
    "    prompt = \"\"\"\n",
    "**Objective:**\n",
    "Produce realistic text passages that include clearly identified named entities. Each entity should be meticulously labeled according to its type for straightforward extraction.\n",
    "\n",
    "**Format Requirements:**\n",
    "- The output should be formatted in JSON, containing the text and the corresponding entities list.\n",
    "- Each entity in the text should be accurately marked and annotated in the 'entities' list.\n",
    "- Meticulously follow all the listed attributes.\n",
    "\n",
    "**Entity Annotation Details:**\n",
    "- All entity types must be in lowercase. For example, use \"type\" not \"TYPE\".\n",
    "- Entity types can be multiwords separate by space. For instance, use \"entity type\" rather than \"entity_type\".\n",
    "- Entities spans can be nested within other entities.\n",
    "- A single entity may be associated with multiple types. list them in the key \"types\".\n",
    "\n",
    "**Output Schema:**\n",
    "\n",
    "<start attribute_1=\"value1\" attribute_2=\"value2\" ...>\n",
    "{\n",
    "  \"text\": \"{text content}\",\n",
    "  \"entities\": [\n",
    "    {\"entity\": \"entity name\", \"types\": [\"type 1\", \"type 2\", ...]},\n",
    "    ...\n",
    "  ]\n",
    "}\n",
    "<end>\n",
    "\n",
    "**Here are some real world examples**:\"\"\"\n",
    "\n",
    "    # Create a string of attributes for the <start> tag, excluding any 'n/a' values\n",
    "    attributes_string = \" \".join([f'{key}=\"{value}\"' for key, value in attributes.items()])\n",
    "\n",
    "    # Adding the dynamically created attributes string to the prompt\n",
    "    prompt += f\"\"\"\n",
    "<start {attributes_string}>\n",
    "\"\"\"\n",
    "\n",
    "    return prompt"
   ]
  },
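  {
   "cell_type": "markdown",
   "id": "sketch-start-tag",
   "metadata": {},
   "source": [
    "To make the `<start ...>` header concrete, the sketch below reproduces only the attribute filtering and tag construction from `create_json_prompt_for_synthetic_data`. The helper name `build_start_tag` is hypothetical and used here for illustration only.\n",
    "\n",
    "```python\n",
    "def build_start_tag(**kwargs):\n",
    "    # Drop attributes explicitly marked as not applicable\n",
    "    attributes = {key: value for key, value in kwargs.items() if value != \"n/a\"}\n",
    "    # Render the remaining attributes as key=\"value\" pairs\n",
    "    attributes_string = \" \".join(f'{key}=\"{value}\"' for key, value in attributes.items())\n",
    "    return f\"<start {attributes_string}>\"\n",
    "\n",
    "tag = build_start_tag(language=\"french\", sector=\"machine learning\", country=\"n/a\")\n",
    "print(tag)  # <start language=\"french\" sector=\"machine learning\">\n",
    "```\n",
    "\n",
    "Because the prompt ends with this opening tag, the model is expected to continue with the JSON body and then emit `<end>`, which is the stop string set in `sampling_params`."
   ]
  },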
  {
   "cell_type": "markdown",
   "id": "72bf847d-8065-408b-90bc-8490e2ad7ae8",
   "metadata": {},
   "source": [
    "## Example of generation"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "id": "18567e19-cf16-4132-99a0-a084d17362ab",
   "metadata": {},
   "outputs": [],
   "source": [
    "import json\n",
    "\n",
    "def generate(**kwargs):\n",
    "    outputs = llm.generate([create_json_prompt_for_synthetic_data(**kwargs)], sampling_params)\n",
    "    return json.loads(outputs[0].outputs[0].text)"
   ]
  },
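  {
   "cell_type": "markdown",
   "id": "sketch-parse-generation",
   "metadata": {},
   "source": [
    "`generate` raises if the model returns anything that is not valid JSON. A minimal sketch of a more forgiving parser is shown below; `parse_generation` is a hypothetical helper, and the checks on the `text` and `entities` keys are assumptions based on the output schema in the prompt.\n",
    "\n",
    "```python\n",
    "import json\n",
    "\n",
    "def parse_generation(raw):\n",
    "    \"\"\"Return the parsed record, or None if the output is unusable.\"\"\"\n",
    "    try:\n",
    "        record = json.loads(raw.strip())\n",
    "    except json.JSONDecodeError:\n",
    "        return None\n",
    "    # Keep only records that follow the schema requested in the prompt\n",
    "    if not isinstance(record, dict):\n",
    "        return None\n",
    "    if not isinstance(record.get(\"text\"), str) or not isinstance(record.get(\"entities\"), list):\n",
    "        return None\n",
    "    return record\n",
    "\n",
    "print(parse_generation('{\"text\": \"Paris\", \"entities\": []}'))\n",
    "print(parse_generation(\"Sure! Here is the JSON:\"))\n",
    "```\n",
    "\n",
    "Returning `None` rather than raising lets the caller decide whether to retry the prompt or drop the sample."
   ]
  },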
  {
   "cell_type": "code",
   "execution_count": 8,
   "id": "81050597-eebc-4119-94c6-5bbebee9d369",
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Processed prompts: 100%|██████████| 1/1 [00:06<00:00,  6.40s/it]\n"
     ]
    },
    {
     "data": {
      "text/plain": [
       "{'text': \"Nous recherchons un Data Scientist expérimenté pour notre équipe de Paris. Votre mission consistera à concevoir et à mettre en œuvre des modèles statistiques et machine learning. Les candidats doivent posséder une solide expérience en Python et en R. Un diplôme universitaire dans le domaine des mathématiques ou de l'informatique est requis. Les meilleurs candidats auront également une bonne connaissance de TensorFlow et Scikit-Learn.\",\n",
       " 'entities': [{'entity': 'Nous', 'types': ['organization']},\n",
       "  {'entity': 'notre équipe', 'types': ['organization']},\n",
       "  {'entity': 'Paris', 'types': ['location']},\n",
       "  {'entity': 'Data Scientist', 'types': ['jobtitle']},\n",
       "  {'entity': 'votre mission', 'types': ['jobdescription']},\n",
       "  {'entity': 'concevoir et mettre en œuvre', 'types': ['jobresponsibility']},\n",
       "  {'entity': 'des modèles statistiques et machine learning',\n",
       "   'types': ['jobresponsibility']},\n",
       "  {'entity': 'Les candidats', 'types': ['person']},\n",
       "  {'entity': 'doivent posséder', 'types': ['requirement']},\n",
       "  {'entity': 'une solide expérience', 'types': ['requirement']},\n",
       "  {'entity': 'en Python et en R', 'types': ['requirement']},\n",
       "  {'entity': 'Un diplôme universitaire', 'types': ['requirement']},\n",
       "  {'entity': \"dans le domaine des mathématiques ou de l'informatique\",\n",
       "   'types': ['requirement']},\n",
       "  {'entity': 'Les meilleurs candidats', 'types': ['person']},\n",
       "  {'entity': 'ont également une bonne connaissance', 'types': ['requirement']},\n",
       "  {'entity': 'de TensorFlow et Scikit-Learn', 'types': ['requirement']}]}"
      ]
     },
     "execution_count": 8,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "generate(language=\"french\", types_of_text=\"detailled job ads\", sector=\"machine learning\", country=\"france\")"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "4e575fac-2b43-482e-904d-507c340c1a80",
   "metadata": {},
   "source": [
    "## Functions"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 9,
   "id": "d1d9b8ed-8c2e-410f-87e7-13392420300c",
   "metadata": {},
   "outputs": [],
   "source": [
    "# post processing functions\n",
    "\n",
    "import re\n",
    "\n",
    "def tokenize_text(text):\n",
    "    \"\"\"Tokenize the input text into a list of tokens.\"\"\"\n",
    "    return re.findall(r'\\w+(?:[-_]\\w+)*|\\S', text)\n",
    "\n",
    "def extract_entities(data):\n",
    "    all_examples = []\n",
    "\n",
    "    for dt in data:\n",
    "\n",
    "        # Attempt to extract entities; skip current record on failure\n",
    "        try:\n",
    "            tokens = tokenize_text(dt['text'])\n",
    "            ents = [(k[\"entity\"], k[\"types\"]) for k in dt['entities']]\n",
    "        except:\n",
    "            continue\n",
    "\n",
    "        spans = []\n",
    "        for entity in ents:\n",
    "            entity_tokens = tokenize_text(str(entity[0]))\n",
    "\n",
    "            # Find the start and end indices of each entity in the tokenized text\n",
    "            for i in range(len(tokens) - len(entity_tokens) + 1):\n",
    "                if \" \".join(tokens[i:i + len(entity_tokens)]).lower() == \" \".join(entity_tokens).lower():\n",
    "                    for el in entity[1]:\n",
    "                        spans.append((i, i + len(entity_tokens) - 1, el.lower().replace('_', ' ')))\n",
    "\n",
    "        # Append the tokenized text and its corresponding named entity recognition data\n",
    "        all_examples.append({\"tokenized_text\": tokens, \"ner\": spans})\n",
    "\n",
    "    return all_examples\n",
    "\n",
    "# generation functions\n",
    "def generate_from_prompts(prompts, llm, sampling_params):\n",
    "    outputs = llm.generate(prompts, sampling_params)\n",
    "\n",
    "    all_outs = []\n",
    "    \n",
    "    for output in outputs:\n",
    "        try:\n",
    "            js = json.loads(output.outputs[0].text.strip())\n",
    "        except:\n",
    "            continue\n",
    "            \n",
    "        all_outs.append(js)\n",
    "\n",
    "    return all_outs, extract_entities(all_outs)"
   ]
  },
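  {
   "cell_type": "markdown",
   "id": "sketch-span-matching",
   "metadata": {},
   "source": [
    "To see how entity strings are mapped to token indices, the sketch below applies the same tokenizer and sliding-window comparison to a single sentence. `find_spans` is a hypothetical helper that mirrors the matching loop in `extract_entities`; it compares token lists directly instead of joined strings, and the end index is inclusive as above.\n",
    "\n",
    "```python\n",
    "import re\n",
    "\n",
    "def tokenize_text(text):\n",
    "    # Same pattern as above: words (optionally hyphenated) or single non-space characters\n",
    "    return re.findall(r'\\w+(?:[-_]\\w+)*|\\S', text)\n",
    "\n",
    "def find_spans(tokens, entity):\n",
    "    \"\"\"Return (start, end) token indices of every case-insensitive match of entity.\"\"\"\n",
    "    entity_tokens = [t.lower() for t in tokenize_text(entity)]\n",
    "    if not entity_tokens:\n",
    "        return []\n",
    "    lowered = [t.lower() for t in tokens]\n",
    "    n = len(entity_tokens)\n",
    "    return [(i, i + n - 1) for i in range(len(lowered) - n + 1)\n",
    "            if lowered[i:i + n] == entity_tokens]\n",
    "\n",
    "tokens = tokenize_text(\"Wanted: E-commerce Strategist in Lima, Peru.\")\n",
    "print(tokens)\n",
    "print(find_spans(tokens, \"Lima, Peru\"))             # [(5, 7)]\n",
    "print(find_spans(tokens, \"e-commerce strategist\"))  # [(2, 3)]\n",
    "```\n",
    "\n",
    "Entities whose surface form does not occur in the text produce no span, so annotations the model invented are silently dropped."
   ]
  },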
  {
   "cell_type": "markdown",
   "id": "668162a9-16dd-487e-b20d-e001060136c4",
   "metadata": {},
   "source": [
    "## Use case: synthetic data for job ads"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 10,
   "id": "34fd267d-fe82-460b-a9e0-9639ffe72432",
   "metadata": {},
   "outputs": [],
   "source": [
    "# I have used GPT-4 to generate these\n",
    "\n",
    "# List of countries\n",
    "countries = [\n",
    "    \"Madagascar\", \"Taiwan\", \"USA\", \"Germany\", \"France\", \"Spain\", \"Russia\", \"China\", \n",
    "    \"Japan\", \"Brazil\", \"India\", \"Egypt\", \"South Africa\", \"Australia\", \"Canada\", \n",
    "    \"Mexico\", \"Indonesia\", \"Nigeria\", \"Turkey\", \"United Kingdom\", \"Italy\", \"Poland\", \n",
    "    \"Argentina\", \"Netherlands\", \"Belgium\", \"Switzerland\", \"Sweden\", \"Norway\", \"Finland\",\n",
    "    \"Denmark\", \"Portugal\", \"Greece\", \"Iran\", \"Thailand\", \"Philippines\", \"Vietnam\", \n",
    "    \"South Korea\", \"Saudi Arabia\", \"Israel\", \"UAE\", \"New Zealand\", \"Ireland\", \"Malaysia\",\n",
    "    \"Singapore\", \"Hong Kong\", \"Czech Republic\", \"Hungary\", \"Romania\", \"Colombia\", \n",
    "    \"Peru\", \"Venezuela\", \"Chile\", \"Morocco\", \"Algeria\", \"Tunisia\", \"Nepal\", \"Pakistan\", \"Bangladesh\", \n",
    "    \"Kazakhstan\", \"Ukraine\", \"Austria\", \"Croatia\", \"Serbia\", \"Kenya\", \"Ghana\", \"Zimbabwe\",\n",
    "    \"Cuba\", \"Panama\", \"Fiji\", \"Mongolia\", \"North Korea\", \"Myanmar\", \"Ethiopia\", \"Tanzania\",\n",
    "    \"Algeria\", \"Libya\", \"Jordan\", \"Qatar\", \"Oman\", \"Kuwait\", \"Lebanon\", \"Bulgaria\", \"Slovakia\",\n",
    "    \"Lithuania\", \"Latvia\", \"Estonia\", \"Cyprus\", \"Luxembourg\", \"Macao\", \"Bhutan\", \"Maldives\",\n",
    "    \"Angola\", \"Cameroon\", \"Senegal\", \"Mali\", \"Zambia\", \"Uganda\", \"Namibia\", \"Botswana\",\n",
    "    \"Mozambique\", \"Ivory Coast\", \"Burkina Faso\", \"Malawi\", \"Gabon\", \"Lesotho\", \"Gambia\",\n",
    "    \"Guinea\", \"Cape Verde\", \"Rwanda\", \"Benin\", \"Burundi\", \"Somalia\", \"Eritrea\", \"Djibouti\",\n",
    "    \"Togo\", \"Seychelles\", \"Chad\", \"Central African Republic\", \"Liberia\", \"Mauritania\", \"Sri Lanka\",\n",
    "    \"Sierra Leone\", \"Equatorial Guinea\", \"Swaziland\", \"Congo (Kinshasa)\", \"Congo (Brazzaville)\"\n",
    "]\n",
    "\n",
    "# job sectors\n",
    "job_sectors = [\n",
    "    # Finance Sector Specializations\n",
    "    \"Investment Banking\",\n",
    "    \"Corporate Finance\",\n",
    "    \"Asset Management\",\n",
    "    \"Risk Management\",\n",
    "    \"Quantitative Analysis\",\n",
    "    \"Financial Planning\",\n",
    "    \n",
    "    # Machine Learning and AI Specializations\n",
    "    \"Natural Language Processing\",\n",
    "    \"Computer Vision\",\n",
    "    \"Deep Learning\",\n",
    "    \"Reinforcement Learning\",\n",
    "    \"Predictive Analytics\",\n",
    "    \"Algorithm Development\",\n",
    "    \n",
    "    # Healthcare Sector Specializations\n",
    "    \"Medical Research\",\n",
    "    \"Clinical Trials\",\n",
    "    \"Health Informatics\",\n",
    "    \"Biomedical Engineering\",\n",
    "    \"Public Health Administration\",\n",
    "    \"Pharmaceuticals\",\n",
    "    \n",
    "    # Education Sector Specializations\n",
    "    \"Curriculum Development\",\n",
    "    \"Educational Technology\",\n",
    "    \"Special Education\",\n",
    "    \"Higher Education Administration\",\n",
    "    \"Educational Policy\",\n",
    "    \"Language Instruction\",\n",
    "    \n",
    "    # Manufacturing Sector Specializations\n",
    "    \"Process Engineering\",\n",
    "    \"Quality Control\",\n",
    "    \"Industrial Design\",\n",
    "    \"Supply Chain Optimization\",\n",
    "    \"Robotics Manufacturing\",\n",
    "    \"Lean Manufacturing\",\n",
    "    \n",
    "    # Energy Sector Specializations\n",
    "    \"Renewable Energy Systems\",\n",
    "    \"Oil and Gas Exploration\",\n",
    "    \"Energy Efficiency Consulting\",\n",
    "    \"Nuclear Engineering\",\n",
    "    \"Smart Grid Technology\",\n",
    "    \"Energy Policy\",\n",
    "    \n",
    "    # Environmental Sector Specializations\n",
    "    \"Wildlife Conservation\",\n",
    "    \"Environmental Science\",\n",
    "    \"Water Resource Management\",\n",
    "    \"Sustainability Strategy\",\n",
    "    \"Climate Change Analysis\",\n",
    "    \"Environmental Law\",\n",
    "    \n",
    "    # Media and Communications Specializations\n",
    "    \"Digital Marketing\",\n",
    "    \"Journalism\",\n",
    "    \"Public Relations\",\n",
    "    \"Film Production\",\n",
    "    \"Broadcasting\",\n",
    "    \"Content Strategy\",\n",
    "    \n",
    "    # Legal Sector Specializations\n",
    "    \"Corporate Law\",\n",
    "    \"International Law\",\n",
    "    \"Intellectual Property\",\n",
    "    \"Environmental Law\",\n",
    "    \"Civil Litigation\",\n",
    "    \"Criminal Defense\",\n",
    "    \n",
    "    # Retail Sector Specializations\n",
    "    \"E-commerce Strategy\",\n",
    "    \"Store Management\",\n",
    "    \"Merchandise Planning\",\n",
    "    \"Customer Experience Management\",\n",
    "    \"Retail Analytics\",\n",
    "    \"Supply Chain Logistics\"\n",
    "]"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "879ec8c9-ca1a-4ba6-a0ca-8170c405b225",
   "metadata": {},
   "source": [
    "### Generate prompts"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 11,
   "id": "8a501874-87b9-4eb0-8fa2-7ed357bcfb06",
   "metadata": {},
   "outputs": [],
   "source": [
    "# create prompts\n",
    "NUM_SAMPLES = 100\n",
    "\n",
    "import random\n",
    "\n",
    "all_prompts = []\n",
    "\n",
    "for i in range(NUM_SAMPLES):\n",
    "    # sample\n",
    "    job_sector = random.choice(job_sectors)\n",
    "    country = random.choice(countries)\n",
    "    \n",
    "    prompt = create_json_prompt_for_synthetic_data(language=\"english\", \n",
    "                                                   types_of_text=\"detailled job ads\", \n",
    "                                                   sector=job_sector, \n",
    "                                                   country=country)\n",
    "    all_prompts.append(prompt)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "e3e7f996-2d77-413e-9aa1-2d94ecae79f7",
   "metadata": {},
   "source": [
    "### Generate outputs"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 12,
   "id": "d3ea1efa-e34c-4945-9197-a8a76bc259b2",
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Processed prompts: 100%|██████████| 100/100 [00:17<00:00,  5.60it/s]\n"
     ]
    }
   ],
   "source": [
    "output, processed_output = generate_from_prompts(all_prompts, llm, sampling_params)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 13,
   "id": "784a003b-973d-456e-812f-dc8931409394",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "{'text': 'Wanted: E-commerce Strategist in Lima, Peru. 5+ years of experience in digital marketing required. B2C or B2B projects preferred. Salary range: S/ 3000 to S/ 5000. Apply with resume and cover letter.',\n",
       " 'entities': [{'entity': 'E-commerce Strategist',\n",
       "   'types': ['person', 'jobtitle']},\n",
       "  {'entity': 'Lima, Peru', 'types': ['location']},\n",
       "  {'entity': '5+ years', 'types': ['quantity', 'duration']},\n",
       "  {'entity': 'digital marketing', 'types': ['skill']},\n",
       "  {'entity': 'B2C or B2B', 'types': ['business_model']},\n",
       "  {'entity': 'Salary range', 'types': ['salary']},\n",
       "  {'entity': 'S/ 3000 to S/ 5000', 'types': ['amount', 'currency']}]}"
      ]
     },
     "execution_count": 13,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "output[0]"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "76345f9e-d9c1-4e2e-8335-72bdad80f2dd",
   "metadata": {},
   "source": [
    "### Some statistics"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 26,
   "id": "3037612f-75f7-4485-9d54-ccb5dc9cd873",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Avg num tokens: 76.82291666666667\n"
     ]
    }
   ],
   "source": [
    "lengths = []\n",
    "\n",
    "for d in processed_output:\n",
    "    lengths.append(len(d[\"tokenized_text\"]))\n",
    "\n",
    "print(\"Avg num tokens:\", sum(lengths) / len(lengths))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 27,
   "id": "86f3e446-fff7-4fd6-b905-b1d2aad3f6ab",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Avg num of entities: 11.875\n"
     ]
    }
   ],
   "source": [
    "len_ner = []\n",
    "\n",
    "for d in processed_output:\n",
    "    len_ner.append(len(d[\"ner\"]))\n",
    "        \n",
    "print(\"Avg num of entities:\", sum(len_ner) / len(len_ner))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 28,
   "id": "1a6cea18-ad66-4f8f-ac1d-5063b6a846d2",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Unique entity types: 1140\n"
     ]
    }
   ],
   "source": [
    "unique_entities = []\n",
    "\n",
    "for d in processed_output:\n",
    "    for n in d[\"ner\"]:\n",
    "        unique_entities.append((str(n[2]).lower()))\n",
    "\n",
    "print(\"Unique entity types:\", len(unique_entities))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 21,
   "id": "1802cefc-6fb8-4ce7-9d03-49ac44833161",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "[('organization', 106),\n",
       " ('location', 86),\n",
       " ('job title', 83),\n",
       " ('person', 71),\n",
       " ('country', 41),\n",
       " ('technology', 40),\n",
       " ('field of study', 38),\n",
       " ('education', 29),\n",
       " ('degree', 24),\n",
       " ('quantity', 23)]"
      ]
     },
     "execution_count": 21,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# Top 10 entity types\n",
    "\n",
    "from collections import Counter\n",
    "Counter(unique_entities).most_common()[:10]"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "4dde99a5-3fb7-453a-a541-3fb72ef8f4a2",
   "metadata": {},
   "source": [
    "### Save for training"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 22,
   "id": "22b57be1-ebca-4697-8100-fa911e84da7a",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Save to JSON\n",
    "def save_data_to_file(data, filepath):\n",
    "    \"\"\"Saves the processed data to a JSON file.\"\"\"\n",
    "    with open(filepath, 'w') as f:\n",
    "        json.dump(data, f)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 23,
   "id": "05c03e6d-7148-4b83-a198-7c2890fe656c",
   "metadata": {},
   "outputs": [],
   "source": [
    "output_file = \"job_ads_data_gliner.json\"\n",
    "\n",
    "save_data_to_file(processed_output, output_file)"
   ]
  },
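  {
   "cell_type": "markdown",
   "id": "sketch-roundtrip-check",
   "metadata": {},
   "source": [
    "As a quick sanity check on the saved format, the sketch below writes one hand-made example (not model output) to a temporary file and reads it back. JSON has no tuple type, so the `(start, end, type)` spans come back as lists. The file name is arbitrary.\n",
    "\n",
    "```python\n",
    "import json\n",
    "import os\n",
    "import tempfile\n",
    "\n",
    "# One hand-written record in the same shape as the entries of processed_output\n",
    "examples = [{\"tokenized_text\": [\"Lima\", \",\", \"Peru\"], \"ner\": [(0, 2, \"location\")]}]\n",
    "\n",
    "path = os.path.join(tempfile.mkdtemp(), \"sample_gliner.json\")\n",
    "with open(path, \"w\", encoding=\"utf-8\") as f:\n",
    "    json.dump(examples, f)\n",
    "\n",
    "with open(path, encoding=\"utf-8\") as f:\n",
    "    loaded = json.load(f)\n",
    "\n",
    "print(loaded[0][\"ner\"])  # [[0, 2, 'location']]\n",
    "```"
   ]
  },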
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "5b6b7131-4da1-43c1-8ab2-08f092fd4543",
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3 (ipykernel)",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.8.18"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
